% Compute the test statistics and bootstrap p-values for the penalized Bierens-max test 
% proposed in Chen et al. (2025) for full-vector inference.
% Utilize the grid search method.

% Arguments:
% datau must be an n*K matrix of generalized residuals, with each column corresponding to a distinct hypothesized value for theta1.
% dataw is an n*p matrix of instruments not including the constant.
% bnd defines [-bnd, bnd]^p as the search space from which the grid is sampled.
% lambdaset is the search space for optimal penalty level

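% Outputs (fields of the returned struct; J = |lambdaset|, and rows follow lambdaset sorted in descending order):
% bs_pv is a J by K matrix of bootstrap p-values, one column per column of datau.
% rejection is a J by K logical matrix, true where the bootstrap p-value is below alpha_level.
% runtime is the computation time in seconds (tic/toc around the grid search and the p-value/power computations).
% teststat is a J by K matrix of test statistics.
% gamma_hat is a K by |lambdaset| by p array of gamma_hat, defined as the argmax of the objective function
%   (note: the current implementation does not return this field).
% When the option 'B' is supplied, the struct additionally contains local_powers (|B| by K by J, with the
%   third dimension in ascending order of lambda), max_Bierens_pv, opt_pv and opt_lambda.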

function output = pen_Bierens_test_GS(datau, dataw, bnd, lambdaset, varargin)
% PEN_BIERENS_TEST_GS  Penalized Bierens-max test via grid search with a multiplier bootstrap.
%
% Required arguments:
%   datau     - n by K matrix; each column is tested separately.
%   dataw     - n by p matrix of instruments (studentized and atan-transformed below).
%   bnd       - scalar scaling the grid returned by num2grid to [-bnd, bnd]^p.
%   lambdaset - vector of candidate penalty levels; it is sorted in DESCENDING
%               order internally and every per-lambda output follows that order
%               unless stated otherwise.
%
% Name-value options:
%   'bR'          - number of bootstrap replications (default 500).
%   'alpha_level' - nominal level of the test (default 0.1).
%   'sz_Gamma'    - total number of grid points (default 10000).
%   'splits'      - number of chunks the grid is processed in (default 10).
%   'B'           - vector of local-alternative magnitudes (default [], no power analysis).
%   'gradient'    - direction of the local deviations; required whenever 'B' is non-empty.
%                   It must be size-compatible with an n by 1 column (see the
%                   addition to the bootstrap residuals below).
%   'seed'        - if non-empty, passed to rng(seed, 'twister').
%   'dm'          - forwarded to hdm as its 'dm' option (default 0).
%
% Output struct fields (J = numel(lambdaset)):
%   bs_pv, rejection, teststat - J by K.
%   lambdaset                  - the descending-sorted penalty grid that indexes the
%                                rows of bs_pv / rejection / teststat.
%   runtime                    - elapsed seconds from the start of the grid search.
%   Only when 'B' is non-empty:
%   local_powers   - numel(B) by K by J; NOTE the third dimension is stored in
%                    reverse (ascending-lambda) order relative to bs_pv.
%   max_Bierens_pv - p-values at the smallest lambda (last row of bs_pv).
%   opt_pv         - 1 by K p-values at the selected lambda.
%   opt_lambda     - the selected penalty level for each column of datau.
ip = inputParser;

addParameter(ip, 'bR', 500);  % Number of bootstrap replications
addParameter(ip, 'alpha_level', 0.1);  % nominal level of the test
addParameter(ip, 'sz_Gamma', 10000)  % Number of grid points
addParameter(ip, 'splits', 10)  % Number of chunks used to evaluate the grid
addParameter(ip, 'B', [])  % Local alternatives
addParameter(ip, 'gradient', [])  % direction of local deviations of the score
addParameter(ip, 'seed', [])
addParameter(ip, 'dm', 0)  % indicator for exponential weight demeaning.
parse(ip, varargin{:});
bR = ip.Results.bR;
alpha_level = ip.Results.alpha_level;
sz_Gamma = ip.Results.sz_Gamma;
B = ip.Results.B;
gradient = ip.Results.gradient;
splits = ip.Results.splits;
seed = ip.Results.seed;
dm = ip.Results.dm;

% Local-power analysis is only run when local alternatives are supplied.
has_local = ~isempty(B);
if has_local && isempty(gradient)
  % Without this check the failure surfaces later as an opaque size-mismatch
  % error when the (empty) local shifts are added to the bootstrap residuals.
  error('pen_Bierens_test_GS:missingGradient', ...
        'The ''gradient'' option must be supplied whenever ''B'' is non-empty.');
end

% Set seed
if ~isempty(seed)
  rng(seed, 'twister')
end
J = length(lambdaset);
lambdaset = sort(lambdaset, 'descend');

% Studentize the instruments, then apply a bounded and strictly monotone transform
dataw = atan(zscore(dataw));
%--------------------------------------------------------------------------
% Test statistics & Multiplier bootstrap
%--------------------------------------------------------------------------
n = size(datau, 1);
K = size(datau, 2);
p = size(dataw, 2);

% Local shifts: for each b in B, gradient*b/sqrt(n) replicated bR times, with the
% blocks laid side by side in the order of B.
nB = 0;
Local_Alts = [];
if has_local
  B = reshape(B, 1, []);
  nB = length(B);
  shift_blocks = cell(1, nB);
  for i = 1:nB
    shift_blocks{i} = repmat(gradient*B(i)/sqrt(n), 1, bR);
  end
  Local_Alts = [shift_blocks{:}];
end

% Multiplier-bootstrap weights, shared by all K columns of datau.
Eta = randn(n, bR, 1);

% For each column k, ustar holds one contiguous block of columns:
%   [original residuals, bR bootstrap residuals, (bR shifted bootstrap residuals per b in B)]
% Blocks are collected in a cell array and concatenated once to avoid regrowing ustar.
blocks = cell(1, K);
for k = 1:K
  boot_k = datau(:, k) .* Eta(:, :, 1);
  if has_local
    blocks{k} = [datau(:, k), boot_k, repmat(boot_k, 1, nB) + Local_Alts];
  else
    blocks{k} = [datau(:, k), boot_k];
  end
end
ustar = [blocks{:}];
clear Eta blocks

cycle = size(ustar, 2) / K;  % number of ustar columns belonging to each column of datau
stat = zeros(size(ustar, 2), J);  % running minimum starts at 0, so the final statistic is never below 0
Tstart = tic();

% Loop-invariant grid bookkeeping.
R_split = ceil(sz_Gamma / splits);  % grid points handled per split
% NOTE(review): round() is applied to sz_Gamma before taking the p-th root, so the
% per-dimension size passed to num2grid can be non-integer — confirm this is intended.
grid_dims = round(sz_Gamma)^(1/p)*ones(p, 1);

for s = 1:splits
  Gamma = bnd*(2*num2grid(grid_dims, (R_split*(s-1)+1):R_split*s)-1);
  if s == 1
    % add slight random errors to the origin at which the objective function is discontinuous
    Gamma = [Gamma; unifrnd(-0.001,0.001,[500, p])];
  end
  [~, o] = hdm(Gamma, ustar, dataw, 0, 'dm', dm);
  for ilambda = 1:J
    % Penalized objective for the sample and every bootstrap sample (rows)
    % across the grid points in Gamma (columns).
    % presumably o.l1 is the l1 penalty term per grid point — confirm against hdm
    q = o.Q + lambdaset(ilambda) * repmat(o.l1, size(ustar, 2), 1);
    stat(:, ilambda) = min(stat(:, ilambda), min(q, [], 2));
  end
end
% Results in (cycle*K) by J matrix
stat = -stat;
bootstrap_pv = zeros(J, K);

% Compute bootstrap p-values: share of bootstrap statistics at least as large
% as the sample statistic, for every lambda at once.
for k = 1:K
  base = (k-1)*cycle;
  null_rows = (base + 2):(base + 1 + bR);
  bootstrap_pv(:, k) = mean(stat(null_rows, :) >= stat(base + 1, :), 1).';
end
ind = 1:cycle:(1+(K-1)*cycle);  % rows of stat holding the sample statistics
output.bs_pv = bootstrap_pv;
output.rejection = (bootstrap_pv < alpha_level);
output.teststat = stat(ind, :)';
output.lambdaset = lambdaset;  % sorted grid, so callers can map rows back to penalty levels
%--------------------------------------------------------------------------
% Select optimal penalty level
%--------------------------------------------------------------------------
if has_local
  bs_cv = zeros(K, J);
  local_powers = zeros(nB, K, J);
  for k = 1:K
    base = (k-1)*cycle;
    null_rows = (base + 2):(base + 1 + bR);
    for ilambda = 1:J
      bs_cv(k, ilambda) = quantile(stat(null_rows, ilambda), 1 - alpha_level);
      for j = 1:nB
        alt_rows = (base + 2 + j*bR):(base + 1 + (j+1)*bR);
        % Stored with the lambda index reversed (ascending lambda), so that the
        % max below breaks ties in favour of the smallest penalty level.
        local_powers(j, k, J+1 - ilambda) = mean(stat(alt_rows, ilambda) > bs_cv(k, ilambda));
      end
    end
  end
  output.local_powers = local_powers;
  obj = reshape(min(local_powers, [], 1), [K, J]);  % Solve local power minimax problem
  [~, lambda_ind] = max(obj, [], 2);
  opt_pv = zeros(1, K);
  for k = 1:K
    % Undo the index reversal to address the descending-sorted rows of bootstrap_pv.
    opt_pv(1, k) = bootstrap_pv(J+1 - lambda_ind(k), k);
  end
  output.max_Bierens_pv = bootstrap_pv(end, :);
  output.opt_pv = opt_pv;
  output.opt_lambda = lambdaset(J+1 - lambda_ind);
end
Tend = toc(Tstart);
output.runtime = Tend;
end